<docs lang="markdown">
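Backend plugin for the UC2 tools. It loads the image at the URL set via `setURL`, segments it (Otsu threshold, small-object removal, binary closing, distance-transform watershed), reports the number of labelled neurons, and shows the label image in an ImJoy window.

TODO: expand this documentation (parameters, expected image formats, usage from the cellphone client).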
</docs>

<config lang="json">
{
    "name": "UC2_tools_janelia",
    "type": "native-python",
    "version": "0.1.0",
    "description": "Backend that processes image data sent from the cellphone.",
    "tags": [],
    "ui": "",
    "cover": "",
    "inputs": null,
    "outputs": null,
    "flags": [],
    "icon": "extension",
    "api_version": "0.1.6",
    "env": "",
    "permissions": [],
    "requirements": ["numpy", "matplotlib", "pillow", "scikit-image", "scipy", "requests"],
    "dependencies": []
}
</config>

<script lang="python">
from imjoy import api
import numpy as np
from scipy import ndimage as ndi
from skimage.morphology import watershed
from skimage.feature import peak_local_max
from skimage import color
from skimage.filters import threshold_otsu
from skimage import morphology
from skimage.morphology import disk
from scipy import ndimage as ndi
from skimage import feature, measure
from PIL import Image
import base64

# Select the non-interactive Agg backend before pyplot is imported.
import matplotlib as mpl
mpl.use('Agg')
from matplotlib import pyplot as plt

# Turn off pyplot's interactive mode.
plt.ioff()

class ImJoyPlugin():
    """ImJoy plugin that segments an image and counts the labelled regions.

    The caller sets the image location with ``setURL`` (and optionally a
    parameter with ``setParam1``) and then invokes ``process``, which
    thresholds the image, runs a distance-transform watershed, reports the
    number of labelled regions ("neurons") and shows the label image in an
    ImJoy window.
    """

    def setup(self):
        """Initialise the plugin state (ImJoy ``setup`` entry point)."""
        self.window = None
        # Defaults so that ``process`` does not fail with AttributeError when
        # the setters have not been called yet.
        self.param1 = None
        self.url = None
        api.log('initialized')

    def getEngineURL(self):
        """Return the engine URL exposed by the ImJoy API object."""
        return api.ENGINE_URL

    def setParam1(self, param):
        """Store ``param``; it is only echoed back at the start of ``process``."""
        self.param1 = param

    def setURL(self, url):
        """Store the location of the image that ``process`` will open."""
        self.url = url

    @staticmethod
    def _count_labels(labels):
        """Return the number of distinct non-zero labels in ``labels``.

        Label 0 is the background. Counting non-zero labels (instead of
        subtracting one from the number of unique values) stays correct when
        the label image contains no background pixels at all.
        """
        return int(np.count_nonzero(np.unique(labels)))

    async def process(self):
        """Segment the image at ``self.url`` and display the labelled result.

        Shows alerts with the parameter value and the neuron count, writes
        the plot to ``MyNeurons.png`` and displays it in an ImJoy window
        (re-using the previous window when possible).
        """
        await api.alert('processing with param1='+ str(self.param1))

        if self.url is None:
            await api.alert('No image URL set; call setURL() first.')
            return

        # Taken from http://opensciencecafe.org/2016/01/counting-change-image-analysis-python/
        # Load the image which was selected through the GUI.
        myimage = np.array(Image.open(self.url))
        myimage_gray = color.rgb2gray(myimage)

        # Global Otsu threshold on the grayscale image.
        thresh = threshold_otsu(myimage_gray, nbins=5)
        thresh_im = myimage_gray > thresh

        # Clean up the mask: drop small specks, then close small gaps.
        no_small = morphology.remove_small_objects(thresh_im, min_size=150)
        neurons = morphology.binary_closing(no_small, disk(3))

        # Watershed segmentation seeded from the distance-transform peaks.
        distance_im = ndi.distance_transform_edt(neurons)
        peaks_im = feature.peak_local_max(distance_im, indices=False)
        markers_im = measure.label(peaks_im)
        # NOTE(review): newer scikit-image releases expose watershed under
        # skimage.segmentation - confirm against the pinned version.
        labelled_neurons = morphology.watershed(-distance_im, markers_im, mask=neurons)
        num_neurons = self._count_labels(labelled_neurons)
        await api.alert('number of neurons: %i' % num_neurons)

        # Plot the result and save it as PNG in order to display it.
        name_plot = 'MyNeurons.png'
        fig, ax = plt.subplots(num='Labelled Neurons')
        try:
            ax.imshow(labelled_neurons, cmap='jet')
            fig.savefig(name_plot, dpi=300)
        finally:
            # Close the figure so repeated runs neither leak figures nor
            # draw on top of the previous plot.
            plt.close(fig)

        with open(name_plot, 'rb') as f:
            encoded = base64.b64encode(f.read()).decode('ascii')
        data = {"src": 'data:image/png;base64,' + encoded}
        data_plot = {
            'name': 'Plot charts: show png',
            'type': 'imjoy/image',
            'w': 16, 'h': 20,
            'data': data}

        # Create the window on first use, otherwise try to update it.
        if self.window is None:
            self.window = await api.createWindow(data_plot)
            print('Window created')
        else:
            print('Update window.')
            try:
                await self.window.run(data=data)
            except Exception:
                # The old window is presumably gone (e.g. closed by the
                # user); fall back to a fresh one.
                self.window = await api.createWindow(data_plot)
                print('Could not print to old window. New window created.')
        await api.showMessage('done!')

    @staticmethod
    def imshow_overlay(im, mask, alpha=0.5, color='red', **kwargs):
        """Show semi-transparent single-colour mask over an image.

        ``im`` is drawn with ``plt.imshow(im, **kwargs)``; pixels where
        ``mask > 0`` are overlaid in ``color`` with opacity ``alpha``.
        """
        # Local import: the file does not import ListedColormap at the top.
        from matplotlib.colors import ListedColormap

        mask = mask > 0
        # Hide everything outside the mask so only masked pixels are drawn.
        mask = np.ma.masked_where(~mask, mask)
        plt.imshow(im, **kwargs)
        plt.imshow(mask, alpha=alpha, cmap=ListedColormap([color]))

    @staticmethod
    def my_imshow(im, title=None, **kwargs):
        """Show ``im`` in a new figure without axes, in gray unless ``cmap`` is given."""
        if 'cmap' not in kwargs:
            kwargs['cmap'] = 'gray'
        plt.figure()
        plt.imshow(im, interpolation='none', **kwargs)
        if title:
            plt.title(title)
        plt.axis('off')

    def run(self, ctx):
        """ImJoy ``run`` entry point; currently only prints a greeting."""
        print('Hello World')

# Instantiate the plugin and export it through the ImJoy API.
api.export(ImJoyPlugin())
</script>
    